import lowp
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import RobertaModel, base_architecture


@register_model("roberta_lowp")
class RobertaLowp(RobertaModel):
    """RoBERTa model whose forward pass runs inside a ``lowp.Lowp`` context.

    The lowp mode and warning flags are taken from the model args
    (``--precision``, ``--warn-patched``, ``--warn-not-patched``).
    """

    @staticmethod
    def add_args(parser):
        """Add RoBERTa arguments plus the lowp-specific options to *parser*."""
        # Zero-argument super() is unavailable in a staticmethod, hence the
        # explicit (class, class) form.
        super(RobertaLowp, RobertaLowp).add_args(parser)
        parser.add_argument('--precision', type=str, default='BF16',
                            help='precision of lowp. default=BF16')
        parser.add_argument('--warn-patched', action='store_true',
                            help='warn on lowp patched functions')
        parser.add_argument('--warn-not-patched', action='store_true',
                            help='warn on lowp non-patched functions')

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance of *cls* from *args* and *task*."""
        # set any default arguments
        roberta_lowp(args)
        # Bind the parent classmethod to ``cls`` (not RobertaLowp) so that
        # subclasses of RobertaLowp are instantiated as the subclass.
        return super(RobertaLowp, cls).build_model(args, task)

    def forward(self, *kargs, **kwargs):
        """Run ``RobertaModel.forward`` inside a ``lowp.Lowp`` context.

        All positional and keyword arguments are forwarded unchanged.
        """
        args = self.args
        # Fall back to the same defaults declared in add_args so that args
        # which never went through the CLI parser (e.g. args restored from
        # a checkpoint saved without these options) do not raise
        # AttributeError here.
        with lowp.Lowp(mode=getattr(args, "precision", "BF16"),
                       warn_patched=getattr(args, "warn_patched", False),
                       warn_not_patched=getattr(args, "warn_not_patched", False)):
            return super(RobertaLowp, self).forward(*kargs, **kwargs)


@register_model_architecture("roberta_lowp", "roberta_lowp")
def roberta_lowp(args):
    """Fill in defaults for the lowp options, then apply the RoBERTa base defaults.

    Existing values on *args* are left untouched; only missing attributes are
    set, using the same defaults that ``RobertaLowp.add_args`` declares.
    """
    # ``precision`` is read by RobertaLowp.forward, so it needs a default too,
    # not just the two warning flags.
    args.precision = getattr(args, "precision", "BF16")
    args.warn_patched = getattr(args, "warn_patched", False)
    args.warn_not_patched = getattr(args, "warn_not_patched", False)
    base_architecture(args)
